import torch
import torchvision
import torch.nn as nn
import argparse
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
import time
#import matplotlib.pyplot as plt
from GD import GD
from SGD import SGD
from SGDwReg import SGDwReg
from DoubleSGD import DoubleSGD
from CreateCNN import CreateCNN
from CreateResNet import CreateResNet
from DataSet import Dataset
from MultiSGD import MultiSGD


# Command-line interface: every flag is mandatory. The two Run* switches are
# passed as integers and interpreted as booleans (non-zero -> True) below.
parser = argparse.ArgumentParser()
for _flag, _flag_type in (
    ('--RunSGD', int),
    ('--eps', int),
    ('--RunSGDwReg', int),
    ('--reg_eps', int),
    ('--lr', float),
    ('--decay_rate', float),
    ('--weight_decay', float),
    ('--bs', int),
    ('--lambda1', float),
    ('--torchseed', int),
):
    parser.add_argument(_flag, type=_flag_type, required=True)
del _flag, _flag_type
args = parser.parse_args()

# Echo the parsed configuration so it is recorded in the run's log output.
print(args)
print(args.eps)

# Boolean switches selecting which trainers run in __main__.
RunSGD = bool(args.RunSGD)
RunSGDwReg = bool(args.RunSGDwReg)
# `eps` is handed to SGD and `reg_eps` to SGDwReg (presumably epoch counts;
# confirm against those functions).
eps, reg_eps = args.eps, args.reg_eps
learning_rate, decay_rate, weight_decay = args.lr, args.decay_rate, args.weight_decay
bs, lambda1, torchseed = args.bs, args.lambda1, args.torchseed

# Fixed experiment settings that are not exposed on the command line.
# Earlier hard-coded defaults (now supplied via CLI flags) were:
# eps = 1000, bs = 64, learning_rate = 0.03, RunSGD = RunSGDwReg = True,
# and K was once derived as int(train_size / bs).
train_size = 60000
alpha = -0.1  # only consumed by the DoubleSGD trainer
B = 50000
K = 20

# Trainers that are switched off regardless of CLI flags.
RunGD = RunDoubleSGD = RunMultiSGD = False

# Model selection and optional checkpoint resume. When LoadModel is False the
# driver resets starting_epoch to 0, so the checkpoint path is only used
# when resuming.
ModelArchitecture = 'CNN'
LoadModel = False
starting_epoch = 3200
ModelPath = f'Saved01/GD11/epoch{starting_epoch}.pt'

# Result fields returned by the SGD-family trainers (SGD, SGDwReg, DoubleSGD),
# in the order they are unpacked and printed.
_SGD_FIELDS = ('Epochs', 'ProductTraces', 'Frobeniuses', 'HessianTraces',
               'TrainLosses', 'TestLosses', 'TestEpochs', 'TestAccuracies')
# GD returns the same fields except 'TestEpochs'.
_GD_FIELDS = ('Epochs', 'ProductTraces', 'Frobeniuses', 'HessianTraces',
              'TrainLosses', 'TestLosses', 'TestAccuracies')
# Labels for MultiSGD's eleven return values, in unpacking order.
_MULTI_SGD_LABELS = ('SGDEpochs',
                     'SGDProductTraces1', 'SGDFrobeniuses1', 'SGDHessianTraces1',
                     'SGDProductTraces2', 'SGDFrobeniuses2', 'SGDHessianTraces2',
                     'SGDProductTraces3', 'SGDFrobeniuses3', 'SGDHessianTraces3',
                     'SGDTrainLosses')


def _prefixed(prefix, fields):
    """Return a tuple of labels formed by prepending *prefix* to each field."""
    return tuple(prefix + field for field in fields)


def _print_results(labels, values):
    """Print each label on its own line, followed by its matching value.

    Args:
        labels: sequence of label strings.
        values: iterable of results, one per label, in the same order.

    Raises:
        ValueError: if the number of values differs from the number of
            labels (the same check tuple-unpacking the results performs).
            Nothing is printed in that case.
    """
    values = tuple(values)
    if len(values) != len(labels):
        raise ValueError('expected {} results, got {}'.format(len(labels), len(values)))
    for label, value in zip(labels, values):
        print(label)
        print(value)


def _build_model(architecture, load_model, model_path):
    """Construct the network named by *architecture*, optionally loading weights.

    Args:
        architecture: 'CNN' (built with CreateCNN) or 'ResNet' (built with
            CreateResNet).
        load_model: if true, load the state dict saved at *model_path* into
            the freshly constructed model.
        model_path: checkpoint path passed to torch.load when *load_model*
            is true.

    Returns:
        The constructed model. For 'ResNet', the names of all trainable
        parameters are printed after construction/loading.

    Raises:
        ValueError: if *architecture* is not supported. (Without this, the
            script would continue with `model` unbound and fail later.)
    """
    if architecture == 'CNN':
        model = CreateCNN()
    elif architecture == 'ResNet':
        model = CreateResNet()
    else:
        raise ValueError('The model you choose is not supported.')
    if load_model:
        model.load_state_dict(torch.load(model_path))
    if architecture == 'ResNet':
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name)
    return model


if __name__ == '__main__':
    starting_time = time.time()
    # Seed before building the model so weight initialisation is reproducible.
    torch.manual_seed(torchseed)
    if not LoadModel:
        # A freshly initialised model starts counting epochs from zero.
        starting_epoch = 0
    model = _build_model(ModelArchitecture, LoadModel, ModelPath)
    train_data, test_data = Dataset(train_size)

    if RunGD:
        print(starting_epoch)
        _print_results(
            _prefixed('GD', _GD_FIELDS),
            GD(model, train_data, test_data, train_size, bs, eps,
               learning_rate, decay_rate,
               starting_epoch=starting_epoch, K=K, B=B))

    if RunSGD:
        _print_results(
            _prefixed('SGD', _SGD_FIELDS),
            SGD(model, train_data, test_data, train_size, bs, eps,
                learning_rate, decay_rate, weight_decay,
                starting_epoch=starting_epoch, B=B, seed=torchseed))

    if RunSGDwReg:
        _print_results(
            _prefixed('SGDwReg', _SGD_FIELDS),
            SGDwReg(model, train_data, test_data, lambda1, train_size, bs,
                    reg_eps, learning_rate, decay_rate, weight_decay,
                    starting_epoch=starting_epoch, B=B, seed=torchseed))

    if RunDoubleSGD:
        _print_results(
            _prefixed('DoubleSGD', _SGD_FIELDS),
            DoubleSGD(model, train_data, test_data, train_size, bs, eps,
                      learning_rate, alpha,
                      starting_epoch=starting_epoch, B=B))

    if RunMultiSGD:
        # NOTE(review): unlike the other trainers, MultiSGD is not passed
        # `model`; confirm this matches its signature.
        _print_results(
            _MULTI_SGD_LABELS,
            MultiSGD(train_data, test_data, train_size, bs, eps))

    # Plotting is disabled (the matplotlib import is commented out). The code
    # below references result names such as GDEpochs that are no longer bound
    # here; bind them from the trainer results before re-enabling it.
    '''
    plt.plot(GDEpochs, GDProductTraces, color='b', label="GD")
    plt.plot(GDEpochs, SGDProductTraces, color='r', label="SGD")
    #for run in range(1, RunNum):
    #    plt.plot(GDEpochs[0], GDProductTraces[run], color='b')
    #    plt.plot(GDEpochs[0], SGDProductTraces[run], color='r')
    # plt.yscale('log')
    plt.legend()
    plt.title("Trace of Hessian and Covariance Matrix product of GD and SGD on MNIST")
    plt.xlabel("Epochs")
    plt.ylabel("Trace")
    plt.savefig("Trace")
    plt.close()

    plt.plot(GDEpochs, GDFrobeniuses, color='b', label="GD")
    plt.plot(GDEpochs, SGDFrobeniuses, color='r', label="SGD")
    #for run in range(1, RunNum):
    #    plt.plot(GDEpochs[0], GDFrobeniuses[run], color='b')
    #    plt.plot(GDEpochs[0], SGDFrobeniuses[run], color='r')
    # plt.yscale('log')
    plt.legend()
    plt.title("Frobenius norm of Hessian of GD and SGD on MNIST")
    plt.xlabel("Epochs")
    plt.ylabel("Frobenius norm")
    plt.savefig("Frobenius")
    plt.close()

    plt.plot(GDEpochs, GDHessianTraces, color='b', label="GD")
    plt.plot(GDEpochs, SGDHessianTraces, color='r', label="SGD")
    #for run in range(1, RunNum):
    #    plt.plot(GDEpochs[0], GDHessianTraces[run], color='b')
    #    plt.plot(GDEpochs[0], SGDHessianTraces[run], color='r')
    # plt.yscale('log')
    plt.legend()
    plt.title("Trace of Hessian of GD and SGD on MNIST")
    plt.xlabel("Epochs")
    plt.ylabel("Trace of Hessian")
    plt.savefig("Hessian")
    plt.close()
    '''

    print("Run time is {}".format(time.time() - starting_time))
